{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !unzip uncased_L-12_H-768_A-12.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "!rm -rf data/.ipynb_checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n",
      "/usr/local/lib/python3.6/dist-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "/usr/local/lib/python3.6/dist-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "/usr/local/lib/python3.6/dist-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "/usr/local/lib/python3.6/dist-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "/usr/local/lib/python3.6/dist-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "/usr/local/lib/python3.6/dist-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n"
     ]
    }
   ],
   "source": [
    "from utils import *\n",
    "import tensorflow as tf\n",
    "from sklearn.model_selection import train_test_split\n",
    "import time\n",
    "import random\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['negative', 'positive']\n",
      "10662\n",
      "10662\n"
     ]
    }
   ],
   "source": [
    "trainset = sklearn.datasets.load_files(container_path = 'data', encoding = 'UTF-8')\n",
    "trainset.data, trainset.target = separate_dataset(trainset,1.0)\n",
    "print (trainset.target_names)\n",
    "print (len(trainset.data))\n",
    "print (len(trainset.target))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "BERT_VOCAB = 'uncased_L-12_H-768_A-12/vocab.txt'\n",
    "BERT_INIT_CHKPNT = 'uncased_L-12_H-768_A-12/bert_model.ckpt'\n",
    "BERT_CONFIG = 'uncased_L-12_H-768_A-12/bert_config.json'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "W0728 12:08:01.512709 139862809098048 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/bert/optimization.py:87: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import bert\n",
    "from bert import run_classifier\n",
    "from bert import optimization\n",
    "from bert import tokenization\n",
    "from bert import modeling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0728 12:08:01.526020 139862809098048 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/bert/tokenization.py:125: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "tokenization.validate_case_matches_checkpoint(True,BERT_INIT_CHKPNT)\n",
    "tokenizer = tokenization.FullTokenizer(\n",
    "      vocab_file=BERT_VOCAB, do_lower_case=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "MAX_SEQ_LENGTH = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10662/10662 [00:02<00:00, 4605.25it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "input_ids, input_masks, segment_ids = [], [], []\n",
    "\n",
    "for text in tqdm(trainset.data):\n",
    "    tokens_a = tokenizer.tokenize(text)\n",
    "    if len(tokens_a) > MAX_SEQ_LENGTH - 2:\n",
    "        tokens_a = tokens_a[:(MAX_SEQ_LENGTH - 2)]\n",
    "    tokens = [\"[CLS]\"] + tokens_a + [\"[SEP]\"]\n",
    "    segment_id = [0] * len(tokens)\n",
    "    input_id = tokenizer.convert_tokens_to_ids(tokens)\n",
    "    input_mask = [1] * len(input_id)\n",
    "    padding = [0] * (MAX_SEQ_LENGTH - len(input_id))\n",
    "    input_id += padding\n",
    "    input_mask += padding\n",
    "    segment_id += padding\n",
    "    \n",
    "    input_ids.append(input_id)\n",
    "    input_masks.append(input_mask)\n",
    "    segment_ids.append(segment_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "bert_config = modeling.BertConfig.from_json_file(BERT_CONFIG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['worse', 'film', 'breaking', 'breaking', 'utterly', 'charming']"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.tokenize(trainset.data[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "epoch = 10\n",
    "batch_size = 60\n",
    "warmup_proportion = 0.1\n",
    "num_train_steps = int(len(input_ids) / batch_size * epoch)\n",
    "num_warmup_steps = int(num_train_steps * warmup_proportion)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model:\n",
    "    def __init__(\n",
    "        self,\n",
    "        dimension_output,\n",
    "        learning_rate = 2e-5,\n",
    "    ):\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.segment_ids = tf.placeholder(tf.int32, [None, None])\n",
    "        self.input_masks = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None])\n",
    "        \n",
    "        model = modeling.BertModel(\n",
    "            config=bert_config,\n",
    "            is_training=True,\n",
    "            input_ids=self.X,\n",
    "            input_mask=self.input_masks,\n",
    "            token_type_ids=self.segment_ids,\n",
    "            use_one_hot_embeddings=False)\n",
    "        \n",
    "        output_layer = model.get_pooled_output()\n",
    "        self.logits = tf.layers.dense(output_layer, dimension_output)\n",
    "        \n",
    "        self.cost = tf.reduce_mean(\n",
    "            tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "                logits = self.logits, labels = self.Y\n",
    "            )\n",
    "        )\n",
    "        \n",
    "        self.optimizer = optimization.create_optimizer(self.cost, learning_rate, \n",
    "                                                       num_train_steps, num_warmup_steps, False)\n",
    "        correct_pred = tf.equal(\n",
    "            tf.argmax(self.logits, 1, output_type = tf.int32), self.Y\n",
    "        )\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0728 12:08:04.883535 139862809098048 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/bert/modeling.py:171: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n",
      "\n",
      "W0728 12:08:04.888481 139862809098048 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/bert/modeling.py:409: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.\n",
      "\n",
      "W0728 12:08:04.932405 139862809098048 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/bert/modeling.py:490: The name tf.assert_less_equal is deprecated. Please use tf.compat.v1.assert_less_equal instead.\n",
      "\n",
      "W0728 12:08:05.471327 139862809098048 lazy_loader.py:50] \n",
      "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "  * https://github.com/tensorflow/io (for I/O related ops)\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n",
      "W0728 12:08:05.490773 139862809098048 deprecation.py:506] From /usr/local/lib/python3.6/dist-packages/bert/modeling.py:358: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n",
      "W0728 12:08:05.533877 139862809098048 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/bert/modeling.py:671: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.dense instead.\n",
      "W0728 12:08:07.983676 139862809098048 deprecation.py:506] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n",
      "W0728 12:08:08.000271 139862809098048 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/bert/optimization.py:27: The name tf.train.get_or_create_global_step is deprecated. Please use tf.compat.v1.train.get_or_create_global_step instead.\n",
      "\n",
      "W0728 12:08:08.005468 139862809098048 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/bert/optimization.py:32: The name tf.train.polynomial_decay is deprecated. Please use tf.compat.v1.train.polynomial_decay instead.\n",
      "\n",
      "W0728 12:08:08.011108 139862809098048 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py:409: div (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Deprecated in favor of operator or tf.math.divide.\n",
      "W0728 12:08:08.023190 139862809098048 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/bert/optimization.py:70: The name tf.trainable_variables is deprecated. Please use tf.compat.v1.trainable_variables instead.\n",
      "\n",
      "W0728 12:08:08.182476 139862809098048 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_grad.py:1205: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    }
   ],
   "source": [
    "dimension_output = 2\n",
    "learning_rate = 2e-5\n",
    "\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model(\n",
    "    dimension_output,\n",
    "    learning_rate\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_attention_weights(tf_graph):\n",
    "    num_layers = bert_config.num_hidden_layers\n",
    "    attns = [{'layer_%s' % i: tf_graph.get_tensor_by_name('bert/encoder/layer_%s/attention/self/Softmax:0' % i)}\n",
    "             for i in range(num_layers)]\n",
    "\n",
    "    return attns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "attns = extract_attention_weights(tf.get_default_graph())\n",
    "model.attns = attns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0728 12:08:20.580060 139862809098048 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use standard file APIs to check for files with this prefix.\n"
     ]
    }
   ],
   "source": [
    "sess.run(tf.global_variables_initializer())\n",
    "var_lists = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope = 'bert')\n",
    "saver = tf.train.Saver(var_list = var_lists)\n",
    "saver.restore(sess, BERT_INIT_CHKPNT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_input_ids, test_input_ids, train_input_masks, test_input_masks, train_segment_ids, test_segment_ids, train_Y, test_Y = train_test_split(\n",
    "    input_ids, input_masks, segment_ids, trainset.target, test_size = 0.2\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 143/143 [03:07<00:00,  1.01it/s, accuracy=0.556, cost=0.742]\n",
      "test minibatch loop: 100%|██████████| 36/36 [00:17<00:00,  2.40it/s, accuracy=0.848, cost=0.383]\n",
      "train minibatch loop:   0%|          | 0/143 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, pass acc: 0.000000, current acc: 0.813835\n",
      "time taken: 204.77616906166077\n",
      "epoch: 0, training loss: 0.546556, training acc: 0.722046, valid loss: 0.449890, valid acc: 0.813835\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 143/143 [03:03<00:00,  1.01it/s, accuracy=0.889, cost=0.137]\n",
      "test minibatch loop: 100%|██████████| 36/36 [00:17<00:00,  2.40it/s, accuracy=0.818, cost=0.486]\n",
      "train minibatch loop:   0%|          | 0/143 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 200.3038547039032\n",
      "epoch: 1, training loss: 0.365002, training acc: 0.846914, valid loss: 0.483403, valid acc: 0.810169\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 143/143 [03:03<00:00,  1.01it/s, accuracy=1, cost=0.00557]   \n",
      "test minibatch loop: 100%|██████████| 36/36 [00:17<00:00,  2.40it/s, accuracy=0.758, cost=0.64] \n",
      "train minibatch loop:   0%|          | 0/143 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 200.25593400001526\n",
      "epoch: 2, training loss: 0.201564, training acc: 0.931176, valid loss: 0.805201, valid acc: 0.776584\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 143/143 [03:03<00:00,  1.01it/s, accuracy=1, cost=0.00301]   \n",
      "test minibatch loop: 100%|██████████| 36/36 [00:17<00:00,  2.40it/s, accuracy=0.727, cost=1.48] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 200.3201220035553\n",
      "epoch: 3, training loss: 0.116890, training acc: 0.966702, valid loss: 0.900214, valid acc: 0.802924\n",
      "\n",
      "break epoch:4\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "import time\n",
    "\n",
    "EARLY_STOPPING, CURRENT_CHECKPOINT, CURRENT_ACC, EPOCH = 3, 0, 0, 0\n",
    "\n",
    "while True:\n",
    "    lasttime = time.time()\n",
    "    if CURRENT_CHECKPOINT == EARLY_STOPPING:\n",
    "        print('break epoch:%d\\n' % (EPOCH))\n",
    "        break\n",
    "\n",
    "    train_acc, train_loss, test_acc, test_loss = 0, 0, 0, 0\n",
    "    pbar = tqdm(\n",
    "        range(0, len(train_input_ids), batch_size), desc = 'train minibatch loop'\n",
    "    )\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(train_input_ids))\n",
    "        batch_x = train_input_ids[i: index]\n",
    "        batch_masks = train_input_masks[i: index]\n",
    "        batch_segment = train_segment_ids[i: index]\n",
    "        batch_y = train_Y[i: index]\n",
    "        acc, cost, _ = sess.run(\n",
    "            [model.accuracy, model.cost, model.optimizer],\n",
    "            feed_dict = {\n",
    "                model.Y: batch_y,\n",
    "                model.X: batch_x,\n",
    "                model.segment_ids: batch_segment,\n",
    "                model.input_masks: batch_masks\n",
    "            },\n",
    "        )\n",
    "        assert not np.isnan(cost)\n",
    "        train_loss += cost\n",
    "        train_acc += acc\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "    pbar = tqdm(range(0, len(test_input_ids), batch_size), desc = 'test minibatch loop')\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(test_input_ids))\n",
    "        batch_x = test_input_ids[i: index]\n",
    "        batch_masks = test_input_masks[i: index]\n",
    "        batch_segment = test_segment_ids[i: index]\n",
    "        batch_y = test_Y[i: index]\n",
    "        acc, cost = sess.run(\n",
    "            [model.accuracy, model.cost],\n",
    "            feed_dict = {\n",
    "                model.Y: batch_y,\n",
    "                model.X: batch_x,\n",
    "                model.segment_ids: batch_segment,\n",
    "                model.input_masks: batch_masks\n",
    "            },\n",
    "        )\n",
    "        test_loss += cost\n",
    "        test_acc += acc\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "\n",
    "    train_loss /= len(train_input_ids) / batch_size\n",
    "    train_acc /= len(train_input_ids) / batch_size\n",
    "    test_loss /= len(test_input_ids) / batch_size\n",
    "    test_acc /= len(test_input_ids) / batch_size\n",
    "\n",
    "    if test_acc > CURRENT_ACC:\n",
    "        print(\n",
    "            'epoch: %d, pass acc: %f, current acc: %f'\n",
    "            % (EPOCH, CURRENT_ACC, test_acc)\n",
    "        )\n",
    "        CURRENT_ACC = test_acc\n",
    "        CURRENT_CHECKPOINT = 0\n",
    "    else:\n",
    "        CURRENT_CHECKPOINT += 1\n",
    "        \n",
    "    print('time taken:', time.time() - lasttime)\n",
    "    print(\n",
    "        'epoch: %d, training loss: %f, training acc: %f, valid loss: %f, valid acc: %f\\n'\n",
    "        % (EPOCH, train_loss, train_acc, test_loss, test_acc)\n",
    "    )\n",
    "    EPOCH += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'time plot grinds increasingly incoherent fashion might wishing watch makes time go faster rather way around'"
      ]
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_x = [input_ids[-1]]\n",
    "batch_segment = [segment_ids[-1]]\n",
    "batch_masks = [input_masks[-1]]\n",
    "trainset.data[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [],
   "source": [
    "attentions = sess.run(model.attns, feed_dict = {model.X: batch_x,\n",
    "                model.segment_ids: batch_segment,\n",
    "                model.input_masks: batch_masks})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [],
   "source": [
    "cls_attn = attentions[0]['layer_0'][:, :, 0, :]\n",
    "cls_attn = np.mean(cls_attn, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(100,)"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "total_weights = np.sum(cls_attn, axis=-1, keepdims=True)\n",
    "attn = (cls_attn / total_weights)[0]\n",
    "attn.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [],
   "source": [
    "def merge_wordpiece_tokens(paired_tokens):\n",
    "    \"\"\"\n",
    "    Combine tokens that have been broken up during the wordpiece tokenization process\n",
    "    The weights of the merged token is the average of their wordpieces\n",
    "    :param paired_tokens: A list of tuples with the form (wordpiece_token, attention_weight)\n",
    "    :return: A list of tuples with the form (word_token, attention_weight)\n",
    "    \"\"\"\n",
    "    new_paired_tokens = []\n",
    "    n_tokens = len(paired_tokens)\n",
    "\n",
    "    i = 0\n",
    "\n",
    "    while i < n_tokens:\n",
    "        current_token, current_weight = paired_tokens[i]\n",
    "        if current_token.startswith('##'):\n",
    "            previous_token, previous_weight = new_paired_tokens.pop()\n",
    "            merged_token = previous_token\n",
    "            merged_weight = [previous_weight]\n",
    "            while current_token.startswith('##'):\n",
    "                merged_token = merged_token + current_token.replace('##', '')\n",
    "                merged_weight.append(current_weight)\n",
    "                i = i + 1\n",
    "                current_token, current_weight = paired_tokens[i]\n",
    "            merged_weight = np.mean(merged_weight)\n",
    "            new_paired_tokens.append((merged_token, merged_weight))\n",
    "\n",
    "        else:\n",
    "            new_paired_tokens.append((current_token, current_weight))\n",
    "            i = i + 1\n",
    "    \n",
    "    words = [i[0] for i in new_paired_tokens if i[0] not in ['[CLS]', '[SEP]', '[PAD]']]\n",
    "    weights = [i[1] for i in new_paired_tokens if i[0] not in ['[CLS]', '[SEP]', '[PAD]']]\n",
    "    weights = np.array(weights)\n",
    "    weights = weights / np.sum(weights)\n",
    "    return list(zip(words, weights))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('time', 0.07921354),\n",
       " ('plot', 0.05070206),\n",
       " ('grinds', 0.073813565),\n",
       " ('increasingly', 0.057528302),\n",
       " ('incoherent', 0.06072644),\n",
       " ('fashion', 0.07929965),\n",
       " ('might', 0.039253872),\n",
       " ('wishing', 0.06646434),\n",
       " ('watch', 0.04416551),\n",
       " ('makes', 0.038636986),\n",
       " ('time', 0.09167157),\n",
       " ('go', 0.09118808),\n",
       " ('faster', 0.051011782),\n",
       " ('rather', 0.055073477),\n",
       " ('way', 0.06224899),\n",
       " ('around', 0.05900185)]"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "merge_wordpiece_tokens(list(zip(tokenizer.convert_ids_to_tokens(batch_x[0]), attn)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(100,)"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cls_attn = attentions[-1]['layer_11'][:, :, 0, :]\n",
    "cls_attn = np.mean(cls_attn, axis=1)\n",
    "total_weights = np.sum(cls_attn, axis=-1, keepdims=True)\n",
    "attn = (cls_attn / total_weights)[0]\n",
    "attn.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('time', 0.056819078),\n",
       " ('plot', 0.0632532),\n",
       " ('grinds', 0.06938076),\n",
       " ('increasingly', 0.08377443),\n",
       " ('incoherent', 0.024481593),\n",
       " ('fashion', 0.02720975),\n",
       " ('might', 0.026110101),\n",
       " ('wishing', 0.015731115),\n",
       " ('watch', 0.08168129),\n",
       " ('makes', 0.08993647),\n",
       " ('time', 0.09198012),\n",
       " ('go', 0.029810423),\n",
       " ('faster', 0.1186272),\n",
       " ('rather', 0.05931811),\n",
       " ('way', 0.05009796),\n",
       " ('around', 0.111788414)]"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "merge_wordpiece_tokens(list(zip(tokenizer.convert_ids_to_tokens(batch_x[0]), attn)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(100,)"
      ]
     },
     "execution_count": 75,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "combined_attentions = []\n",
    "for a in attentions:\n",
    "    combined_attentions.append(list(a.values())[0])\n",
    "    \n",
    "cls_attn = np.mean(combined_attentions, axis = 0).mean(axis = 2)\n",
    "cls_attn = np.mean(cls_attn, axis=1)\n",
    "total_weights = np.sum(cls_attn, axis=-1, keepdims=True)\n",
    "attn = (cls_attn / total_weights)[0]\n",
    "attn.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('time', 0.053206075),\n",
       " ('plot', 0.058655694),\n",
       " ('grinds', 0.058348235),\n",
       " ('increasingly', 0.056193475),\n",
       " ('incoherent', 0.04362384),\n",
       " ('fashion', 0.0641345),\n",
       " ('might', 0.073644035),\n",
       " ('wishing', 0.07437738),\n",
       " ('watch', 0.06467828),\n",
       " ('makes', 0.07696252),\n",
       " ('time', 0.054407526),\n",
       " ('go', 0.04876398),\n",
       " ('faster', 0.07310601),\n",
       " ('rather', 0.07628402),\n",
       " ('way', 0.045983776),\n",
       " ('around', 0.07763063)]"
      ]
     },
     "execution_count": 76,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "merge_wordpiece_tokens(list(zip(tokenizer.convert_ids_to_tokens(batch_x[0]), attn)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
